"""Load/unload cups to/from dishwasher."""
from abc import ABC

import numpy as np
from pyquaternion import Quaternion

from bigym.bigym_env import BiGymEnv, MAX_DISTANCE_FROM_TARGET
from bigym.const import HandSide
from bigym.envs.props.dishwasher import Dishwasher
from bigym.envs.props.cabintets import BaseCabinet, WallCabinet
from bigym.envs.props.tableware import Mug
from bigym.utils.env_utils import get_random_sites
from bigym.utils.physics_utils import distance


# Scene layout (world frame). Every prop shares the same yaw of -90 degrees
# about Z. Cabinet 1 occupies the same position as the dishwasher; cabinet 2
# is shifted by 0.6 along -Y.
TABLE_1_POS = np.array((1, 0, 0))
TABLE_1_ROT = np.array((0, 0, -0.5 * np.pi))
TABLE_2_POS = np.array((1, -0.6, 0))
TABLE_2_ROT = np.array((0, 0, -0.5 * np.pi))

DISHWASHER_POS = np.array((1, 0, 0))
DISHWASHER_ROT = np.array((0, 0, -0.5 * np.pi))


class _DishwasherCupsEnv(BiGymEnv, ABC):
    """Base cups environment.

    Builds the shared scene for the dishwasher cup tasks: two base cabinets,
    a dishwasher and ``_CUPS_COUNT`` mugs. An episode fails when the robot's
    pelvis drifts farther than ``MAX_DISTANCE_FROM_TARGET`` from the
    dishwasher body, or when any mug collides with the floor.
    """

    # Robot spawn position; presumably consumed by BiGymEnv — TODO confirm.
    _DEFAULT_ROBOT_POS = np.array([0, -0.6, 1])

    # Number of Mug props spawned into the scene.
    _CUPS_COUNT = 2

    def _initialize_env(self):
        """Create the cabinets, dishwasher and mugs, and place the fixed props."""
        self.cabinet_1: BaseCabinet = BaseCabinet(self._mojo, walls_enable=False)
        self.cabinet_2: BaseCabinet = BaseCabinet(self._mojo, panel_enable=True)
        self.dishwasher: Dishwasher = Dishwasher(self._mojo)

        # (prop, world position, euler angles) applied in a fixed order.
        layout = (
            (self.cabinet_1, TABLE_1_POS, TABLE_1_ROT),
            (self.cabinet_2, TABLE_2_POS, TABLE_2_ROT),
            (self.dishwasher, DISHWASHER_POS, DISHWASHER_ROT),
        )
        for prop, position, euler in layout:
            prop.body.set_position(position)
            prop.body.set_euler(euler)

        self.cups = [Mug(self._mojo) for _ in range(self._CUPS_COUNT)]

    def _fail(self) -> bool:
        """Return True if the robot strayed too far or a mug hit the floor."""
        pelvis_offset = distance(self._robot.pelvis, self.dishwasher.body)
        if pelvis_offset > MAX_DISTANCE_FROM_TARGET:
            return True
        return any(cup.is_colliding(self.floor) for cup in self.cups)

    def _on_reset(self):
        """Set the dishwasher's articulated parts to their starting state.

        The door and middle tray are set to 1 and the bottom tray to 0
        (presumably open / extended vs. closed — confirm in Dishwasher).
        """
        self.dishwasher.set_state(door=1, bottom_tray=0, middle_tray=1)


class DishwasherUnloadCups(_DishwasherCupsEnv):
    """Unload cups from dishwasher task.

    Mugs start in randomly chosen slots of the dishwasher's middle tray. The
    task succeeds once every mug touches the counter of either base cabinet
    and no gripper is holding it.
    """

    # Passed to ``get_random_sites`` as its step and slice arguments when
    # choosing which middle-tray sites receive a mug.
    _SITES_STEP = 3
    _SITES_SLICE = 3

    # Mug orientation: rotate about X by _CUPS_ROT_X, then about Z by
    # _CUPS_ROT_Z plus uniform noise in [-_CUPS_ROT_BOUNDS, _CUPS_ROT_BOUNDS].
    _CUPS_ROT_X = np.deg2rad(180)
    _CUPS_ROT_Z = np.deg2rad(90)
    _CUPS_ROT_BOUNDS = np.deg2rad(5)
    # Offset added to each chosen tray site's position.
    _CUPS_POS = np.array([0, -0.05, 0.05])
    # NOTE(review): not referenced anywhere in this module; confirm before removing.
    _CUPS_STEP = np.array([0.115, 0, 0])

    def _get_task_privileged_obs_space(self):
        """Return an empty dict: this task adds no privileged observations."""
        return {}

    def _get_task_privileged_obs(self):
        """Return an empty dict: this task adds no privileged observations."""
        return {}

    def _success(self) -> bool:
        """Return True when every mug is on a counter and not held.

        A mug counts as unloaded when it collides with the counter of
        ``cabinet_1`` or ``cabinet_2`` and neither gripper holds it.
        """
        for cup in self.cups:
            on_counter = cup.is_colliding(
                self.cabinet_1.counter
            ) or cup.is_colliding(self.cabinet_2.counter)
            if not on_counter:
                return False
            for side in HandSide:
                if self.robot.is_gripper_holding_object(cup, side):
                    return False
        return True

    def _on_reset(self):
        """Reset the dishwasher, then place each mug on a random tray site."""
        super()._on_reset()
        sites = get_random_sites(
            self.dishwasher.tray_middle.site_sets[0],
            len(self.cups),
            self._SITES_STEP,
            self._SITES_SLICE,
        )
        # The X rotation is identical for every mug, so build it once.
        # ``a * b`` equals the original ``a *= b`` (pyquaternion returns a new
        # quaternion either way).
        flip = Quaternion(axis=[1, 0, 0], angle=self._CUPS_ROT_X)
        for site, cup in zip(sites, self.cups):
            angle = np.random.uniform(-self._CUPS_ROT_BOUNDS, self._CUPS_ROT_BOUNDS)
            quat = flip * Quaternion(axis=[0, 0, 1], angle=self._CUPS_ROT_Z + angle)
            cup.body.set_quaternion(quat.elements, True)
            # Build a fresh array rather than using ``+=`` on the value returned
            # by ``get_position``: if that value is a view into simulator data,
            # an in-place add would silently move the site itself.
            pos = site.get_position() + self._CUPS_POS
            cup.body.set_position(pos, True)


class DishwasherUnloadCupsLong(DishwasherUnloadCups):
    """Unload cup from dishwasher in wall cabinet task.

    Adds a glass-door wall cabinet placed at the second counter's pose. It
    succeeds when the dishwasher and wall cabinet states are within
    ``_TOLERANCE`` of zero (presumably fully closed — confirm in the props),
    and every mug touches the wall cabinet's bottom shelf without being held.
    """

    _CUPS_COUNT = 1
    _SITES_SLICE = 2

    # Absolute tolerance for the "articulated state is zero" checks.
    _TOLERANCE = 0.1

    def _initialize_env(self):
        """Build the base scene, then add the wall cabinet."""
        super()._initialize_env()
        self.cabinet = WallCabinet(self._mojo, glass_doors_enable=True)
        wall_body = self.cabinet.body
        wall_body.set_position(TABLE_2_POS)
        wall_body.set_euler(TABLE_2_ROT)

    def _success(self) -> bool:
        """Return True when both props are reset and every mug is shelved."""
        articulated = (self.dishwasher, self.cabinet)
        # Lazy generator: the cabinet is only queried if the dishwasher passes.
        all_closed = all(
            np.allclose(prop.get_state(), 0, atol=self._TOLERANCE)
            for prop in articulated
        )
        if not all_closed:
            return False
        for cup in self.cups:
            if not cup.is_colliding(self.cabinet.shelf_bottom):
                return False
            held = any(
                self.robot.is_gripper_holding_object(cup, side) for side in HandSide
            )
            if held:
                return False
        return True


class DishwasherLoadCups(_DishwasherCupsEnv):
    """Load cups to dishwasher task.

    Mugs spawn in a row starting at ``_CUPS_POS`` and spaced by
    ``_CUPS_POS_STEP``. Each mug gets uniform position jitter within
    ``±_CUPS_POS_BOUNDS`` and yaw jitter within ``±_CUPS_ROT_BOUNDS``. The
    task succeeds when every mug touches the dishwasher's middle-tray
    colliders and no gripper is holding it.
    """

    # World position of the first mug, and per-mug offset along the row.
    _CUPS_POS = np.array([0.6, -0.6, 1])
    _CUPS_POS_STEP = np.array([0, 0.15, 0])
    # Half-widths of the uniform position noise, per axis.
    _CUPS_POS_BOUNDS = np.array([0.05, 0.02, 0])
    # Orientation: rotate about X, then about Z plus uniform yaw noise.
    _CUPS_ROT_X = np.deg2rad(180)
    _CUPS_ROT_Z = np.deg2rad(180)
    _CUPS_ROT_BOUNDS = np.deg2rad(30)

    def _get_task_privileged_obs_space(self):
        """No task-specific privileged observations; returns an empty dict."""
        return {}

    def _get_task_privileged_obs(self):
        """No task-specific privileged observations; returns an empty dict."""
        return {}

    def _success(self) -> bool:
        """Return True if every mug touches the middle tray and is not held."""
        return all(
            cup.is_colliding(self.dishwasher.tray_middle.colliders)
            and not any(
                self.robot.is_gripper_holding_object(cup, side) for side in HandSide
            )
            for cup in self.cups
        )

    def _on_reset(self):
        """Reset the dishwasher, then lay the mugs out in a jittered row."""
        super()._on_reset()
        # Shared X rotation; ``flip * q`` matches the original ``quat *= q``.
        flip = Quaternion(axis=[1, 0, 0], angle=self._CUPS_ROT_X)
        for index, cup in enumerate(self.cups):
            yaw = self._CUPS_ROT_Z + np.random.uniform(
                -self._CUPS_ROT_BOUNDS, self._CUPS_ROT_BOUNDS
            )
            orientation = flip * Quaternion(axis=[0, 0, 1], angle=yaw)
            cup.body.set_quaternion(orientation.elements, True)
            jitter = np.random.uniform(-self._CUPS_POS_BOUNDS, self._CUPS_POS_BOUNDS)
            cup.body.set_position(
                self._CUPS_POS + index * self._CUPS_POS_STEP + jitter, True
            )
